{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stable Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTES:  \n",
    "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
    "* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks.tensorboard import *\n",
    "from fastai.vision.gan import *\n",
    "from fasterai.generators import *\n",
    "from fasterai.critics import *\n",
    "from fasterai.dataset import *\n",
    "from fasterai.loss import *\n",
    "from fasterai.save import *\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from PIL import ImageFile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
    "path_hr = path\n",
    "path_lr = path/'bandw'\n",
    "\n",
    "proj_id = 'StableModel'\n",
    "\n",
    "gen_name = proj_id + '_gen'\n",
    "pre_gen_name = gen_name + '_0'\n",
    "crit_name = proj_id + '_crit'\n",
    "\n",
    "name_gen = proj_id + '_image_gen'\n",
    "path_gen = path/name_gen\n",
    "\n",
    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
    "\n",
    "nf_factor = 2\n",
    "pct_start = 1e-8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(bs:int, sz:int, keep_pct:float):\n",
    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
    "                             random_seed=None, keep_pct=keep_pct)\n",
    "\n",
    "def get_crit_data(classes, bs, sz):\n",
    "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
    "    ll = src.label_from_folder(classes=classes)\n",
    "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
    "           .databunch(bs=bs).normalize(imagenet_stats))\n",
    "    return data\n",
    "\n",
    "def create_training_images(fn,i):\n",
    "    dest = path_lr/fn.relative_to(path_hr)\n",
    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
    "    img.save(dest)  \n",
    "    \n",
    "def save_preds(dl):\n",
    "    i=0\n",
    "    names = dl.dataset.items\n",
    "    \n",
    "    for b in dl:\n",
    "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
    "        for o in preds:\n",
    "            o.save(path_gen/names[i].name)\n",
    "            i += 1\n",
    "    \n",
    "def save_gen_images():\n",
    "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
    "    path_gen.mkdir(exist_ok=True)\n",
    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
    "    save_preds(data_gen.fix_dl)\n",
    "    PIL.Image.open(path_gen.ls()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create black and white training images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Only runs if the directory isn't already created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not path_lr.exists():\n",
    "    il = ImageList.from_folder(path_hr)\n",
    "    parallel(create_training_images, il.items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-train generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE\n",
    "Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 64px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=88\n",
    "sz=64\n",
    "keep_pct=1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 128px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=20\n",
    "sz=128\n",
    "keep_pct=1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7,1e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 192px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192\n",
    "keep_pct=0.50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repeatable GAN Cycle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE\n",
    "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_checkpoint_num = 0\n",
    "checkpoint_num = old_checkpoint_num + 1\n",
    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Generated Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_gen_images()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pretrain Critic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if old_checkpoint_num == 0:\n",
    "    bs=64\n",
    "    sz=128\n",
    "    learn_gen=None\n",
    "    gc.collect()\n",
    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
    "    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
    "    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
    "    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
    "    learn_critic.fit_one_cycle(6, 1e-3)\n",
    "    learn_critic.save(crit_old_checkpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=16\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.fit_one_cycle(4, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.save(crit_new_checkpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit=None\n",
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=2e-5\n",
    "sz=192\n",
    "bs=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Instructions:  \n",
    "Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
